import numpy as np
import paddle

from QNN_model import train_model


def fetch_data_random_seed_val(n_samples, seed=0, *, data_path='./data/pulsar.csv',
                               test_size=0.2, max_train=80, max_test=20):
    """Build a class-balanced, scaled train/test split from a CSV file.

    The CSV is read with pandas. Columns 0-7 are used as features and
    column 8 as the label (assumed to be a binary 0/1 class column --
    only rows whose label equals 0 or 1 are ever selected). ``n_samples``
    rows are drawn from each of the two classes, the class-0 rows are
    stacked before the class-1 rows, every feature is min-max scaled to
    ``[0, pi]``, and the result is shuffled and split by
    ``train_test_split``.

    Args:
        n_samples: Number of rows to draw from *each* class.
        seed: Random state used for both the per-class sampling and the
            train/test split.
        data_path: Location of the CSV file. Defaults to the path the
            function originally hard-coded.
        test_size: Forwarded to ``train_test_split``.
        max_train: Upper bound on the number of training rows returned;
            ``None`` disables the cap. The default of 80 matches the
            original hard-coded slice.
        max_test: Upper bound on the number of test rows returned;
            ``None`` disables the cap. The default of 20 matches the
            original hard-coded slice.

    Returns:
        ``(train_X, test_X, train_y, test_y)`` as paddle tensors. The
        label tensors are created with ``dtype="float64"``; the feature
        tensors use whatever dtype ``paddle.to_tensor`` infers from the
        scaled NumPy array.

    Raises:
        Whatever pandas raises when the file cannot be read or when a
        class holds fewer than ``n_samples`` rows (``DataFrame.sample``
        is called without replacement).
    """
    import pandas as pd
    from sklearn import preprocessing
    from sklearn.model_selection import train_test_split

    dataset = pd.read_csv(data_path)

    feature_columns = dataset.columns[0:8]
    label_column = dataset.columns[8]

    def _sample_class(label):
        # Draw n_samples rows of a single class, reproducibly.
        rows = dataset[dataset[label_column] == label]
        rows = rows.sample(n=n_samples, random_state=seed)
        return rows[feature_columns].values, rows[label_column].values

    X0, Y0 = _sample_class(0)
    X1, Y1 = _sample_class(1)

    X = np.concatenate((X0, X1), axis=0)
    Y = np.concatenate((Y0, Y1), axis=0)

    # NOTE(review): the scaler is fit on the combined sample *before* the
    # split, so the test rows contribute to the min/max used for scaling.
    # Kept as-is to preserve the existing numerical results.
    min_max_scaler = preprocessing.MinMaxScaler(feature_range=(0, np.pi))
    X = min_max_scaler.fit_transform(X)

    # Separate the test and training datasets
    train_X, test_X, train_y, test_y = train_test_split(
        X, Y, test_size=test_size, random_state=seed
    )

    train_X = paddle.to_tensor(train_X)
    test_X = paddle.to_tensor(test_X)
    train_y = paddle.to_tensor(train_y, dtype="float64")
    test_y = paddle.to_tensor(test_y, dtype="float64")

    # Optional caps on the returned sizes; with the defaults this is the
    # same [:80] / [:20] truncation the function has always applied.
    if max_train is not None:
        train_X, train_y = train_X[:max_train], train_y[:max_train]
    if max_test is not None:
        test_X, test_y = test_X[:max_test], test_y[:max_test]

    return train_X, test_X, train_y, test_y


if __name__ == '__main__':
    # Script entry point: load a balanced sample of the pulsar CSV, then
    # call train_model once per seed and save the two values it returns.
    import os

    data_size = 100
    block = 4
    depth = 2
    num_qubit = 8
    # NOTE(review): `encoding` is defined but never passed to train_model
    # in this script; kept in case it is read elsewhere -- confirm.
    encoding = 'angle_encoding'

    # np.save does not create missing parent directories. Create the
    # output directory up front so a missing folder cannot make the save
    # fail after a full training run has already completed.
    output_dir = 'extension/classification/output/pulsar'
    os.makedirs(output_dir, exist_ok=True)

    # obtain pulsar data (data_size // 2 rows are requested per class)
    train_X, test_X, train_y, test_y = fetch_data_random_seed_val(n_samples=data_size//2, seed=0)

    print(train_X.shape)
    print(train_y.shape)
    print(test_X.shape)
    print(test_y.shape)

    # The same data split is reused for every run; only the seed handed to
    # train_model changes between iterations.
    num_samples = 10
    for seed in range(num_samples):
        train_loss, test_accuracy = train_model(train_X, train_y, test_X, test_y, seed=seed, N=num_qubit, DEPTH=depth, BLOCK=block, EPOCH=10, BATCH_SIZE=40, LR=0.1)
        # File names encode (num_qubit, block, depth, seed).
        run_tag = '%s_%s_%s_%s' % (num_qubit, block, depth, seed)
        np.save('%s/train_loss_%s.npy' % (output_dir, run_tag), train_loss)
        np.save('%s/test_accuracy_%s.npy' % (output_dir, run_tag), test_accuracy)
        print('----------------------------------------------------------------------------------sample: %s/%s  finished'%(seed+1, num_samples))